{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8697e357",
   "metadata": {},
   "source": [
    "# Pilpeline\n",
    "Pipeline是一种简单的方法，能让数据预处理和建模步骤一步到位。\n",
    "### Pipleline的好处：\n",
    "###### 1.更精简的代码：考虑到数据处理时会造成混乱，使用pipeline不需要在每个步骤都特别注意训练和验证数据。\n",
    "###### 2.更少的Bug：错误应用和忘记处理步骤的概率更小。\n",
    "###### 3.更易产品化：把模型转化成规模化发布原型是比较难的，但pipeline对这个有帮助。\n",
    "###### 4.模型验证更加多样化：例如交叉验证。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "25590e61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Type</th>\n",
       "      <th>Method</th>\n",
       "      <th>Regionname</th>\n",
       "      <th>Rooms</th>\n",
       "      <th>Distance</th>\n",
       "      <th>Postcode</th>\n",
       "      <th>Bedroom2</th>\n",
       "      <th>Bathroom</th>\n",
       "      <th>Car</th>\n",
       "      <th>Landsize</th>\n",
       "      <th>BuildingArea</th>\n",
       "      <th>YearBuilt</th>\n",
       "      <th>Lattitude</th>\n",
       "      <th>Longtitude</th>\n",
       "      <th>Propertycount</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>12167</th>\n",
       "      <td>u</td>\n",
       "      <td>S</td>\n",
       "      <td>Southern Metropolitan</td>\n",
       "      <td>1</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3182.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1940.0</td>\n",
       "      <td>-37.85984</td>\n",
       "      <td>144.9867</td>\n",
       "      <td>13240.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6524</th>\n",
       "      <td>h</td>\n",
       "      <td>SA</td>\n",
       "      <td>Western Metropolitan</td>\n",
       "      <td>2</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3016.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>193.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>-37.85800</td>\n",
       "      <td>144.9005</td>\n",
       "      <td>6380.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8413</th>\n",
       "      <td>h</td>\n",
       "      <td>S</td>\n",
       "      <td>Western Metropolitan</td>\n",
       "      <td>3</td>\n",
       "      <td>12.6</td>\n",
       "      <td>3020.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>555.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>-37.79880</td>\n",
       "      <td>144.8220</td>\n",
       "      <td>3755.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2919</th>\n",
       "      <td>u</td>\n",
       "      <td>SP</td>\n",
       "      <td>Northern Metropolitan</td>\n",
       "      <td>3</td>\n",
       "      <td>13.0</td>\n",
       "      <td>3046.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>265.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1995.0</td>\n",
       "      <td>-37.70830</td>\n",
       "      <td>144.9158</td>\n",
       "      <td>8870.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6043</th>\n",
       "      <td>h</td>\n",
       "      <td>S</td>\n",
       "      <td>Western Metropolitan</td>\n",
       "      <td>3</td>\n",
       "      <td>13.3</td>\n",
       "      <td>3020.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>673.0</td>\n",
       "      <td>673.0</td>\n",
       "      <td>1970.0</td>\n",
       "      <td>-37.76230</td>\n",
       "      <td>144.8272</td>\n",
       "      <td>4217.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      Type Method             Regionname  Rooms  Distance  Postcode  Bedroom2  \\\n",
       "12167    u      S  Southern Metropolitan      1       5.0    3182.0       1.0   \n",
       "6524     h     SA   Western Metropolitan      2       8.0    3016.0       2.0   \n",
       "8413     h      S   Western Metropolitan      3      12.6    3020.0       3.0   \n",
       "2919     u     SP  Northern Metropolitan      3      13.0    3046.0       3.0   \n",
       "6043     h      S   Western Metropolitan      3      13.3    3020.0       3.0   \n",
       "\n",
       "       Bathroom  Car  Landsize  BuildingArea  YearBuilt  Lattitude  \\\n",
       "12167       1.0  1.0       0.0           NaN     1940.0  -37.85984   \n",
       "6524        2.0  1.0     193.0           NaN        NaN  -37.85800   \n",
       "8413        1.0  1.0     555.0           NaN        NaN  -37.79880   \n",
       "2919        1.0  1.0     265.0           NaN     1995.0  -37.70830   \n",
       "6043        1.0  2.0     673.0         673.0     1970.0  -37.76230   \n",
       "\n",
       "       Longtitude  Propertycount  \n",
       "12167    144.9867        13240.0  \n",
       "6524     144.9005         6380.0  \n",
       "8413     144.8220         3755.0  \n",
       "2919     144.9158         8870.0  \n",
       "6043     144.8272         4217.0  "
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 加载数据 X_train、X_valid、y_train和y_valid\n",
    "import pandas as pd \n",
    "from sklearn.model_selection import train_test_split\n",
    "melb_data = pd.read_csv('../melb_data.csv')\n",
    "y = melb_data.loc[:,'Price']\n",
    "X_col = ['Type','Method','Regionname','Rooms','Distance','Postcode',\n",
    "       'Bedroom2','Bathroom','Car','Landsize','BuildingArea','YearBuilt',\n",
    "       'Lattitude','Longtitude','Propertycount']\n",
    "\n",
    "X = melb_data.loc[:,X_col]\n",
    "X_train,X_valid,y_train,y_valid = train_test_split(X,y,test_size=0.2,train_size=0.8,random_state=0)\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65e917bf",
   "metadata": {},
   "source": [
    "#### 步骤1：定义处理步骤\n",
    "###### 填充缺失数据为数值类型（用均值等方法填充数值类型数据）\n",
    "###### 用one-hot编码来填充分类缺失数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "ec96e2d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "# 预处理数值类型数据\n",
    "numerical_transformer = SimpleImputer(strategy='constant')\n",
    "\n",
    "# 预处理分类数据\n",
    "categorical_transformer = Pipeline(steps=[\n",
    "    ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "    ('onehot', OneHotEncoder(handle_unknown='ignore'))\n",
    "])\n",
    "numerical_cols = [col for col in X_train.columns if X_train[col].dtype in ['int64', 'float64']]\n",
    "categorical_cols = [col for col in X_train.columns if X_train[col].nunique() < 10 and X_train[col].dtype == 'object']\n",
    "# 将处理步骤打包\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', numerical_transformer, numerical_cols),\n",
    "        ('cat', categorical_transformer, categorical_cols)\n",
    "    ])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f4846b4",
   "metadata": {},
   "source": [
    "#### 步骤2：定义模型\n",
    "###### 接下来我们使用熟悉的RandomForestRegressor类定义一个随机森林模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "6f575e0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestRegressor\n",
    "\n",
    "model = RandomForestRegressor(n_estimators=100, random_state=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c40fb8f0",
   "metadata": {},
   "source": [
    "#### 步骤3：创建和评估Pipeline\n",
    "###### 使用pipeline，我们预处理训练数据以及拟合模型只有了一行代码。（相反，如果不使用pipeline，我们需要做填充、编码、和模型训练的步骤。如果我们需要同时处理数值和分类变量，这将非常繁琐！）\n",
    "###### 我们在调用predict()指令时使用的X_valid中包含未经处理的特征值，pipeline会在预测前自动进行预处理。（然而，如果不使用pipeline，在预测前，我们需要记住对验证数据进行预处理）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "3ad1bc97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE: 160679.18917034855\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "# 在pipeline中打包预处理和建模代码\n",
    "my_pipeline = Pipeline(steps=[('preprocessor', preprocessor),\n",
    "                              ('model', model)\n",
    "                             ])\n",
    "\n",
    "# 预处理训练数据，拟合模型 \n",
    "my_pipeline.fit(X_train, y_train)\n",
    "\n",
    "# 预处理验证数据, 获取预测值\n",
    "preds = my_pipeline.predict(X_valid)\n",
    "\n",
    "# 评估模型\n",
    "score = mean_absolute_error(y_valid, preds)\n",
    "print('MAE:', score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7ecf93a",
   "metadata": {},
   "source": [
    "Pipeline在机器学习代码清理和规避错误中，尤其是复杂数据预处理的工作流，非常实用。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
